package gorm_v2

import (
	"chongwu/app/conf"
	"chongwu/app/library/my_errors"
	"errors"
	"fmt"
	"go.uber.org/zap"
	"gorm.io/driver/mysql"
	"gorm.io/driver/postgres"
	"gorm.io/driver/sqlserver"
	"gorm.io/gorm"
	"log"
	"strings"
)

// UseDbConn opens a GORM connection to the configured database named by
// database. It uses the "master" settings when isMaster is true and the
// "slave" settings otherwise.
//
// It returns nil, after logging the cause, when the configured driver type is
// unsupported or gorm.Open reports an error. When conf.Conf.AppDebug is set,
// the returned handle is switched into GORM debug mode via db.Debug().
func UseDbConn(database string, isMaster bool) *gorm.DB {
	dbDialector, err := getDbDialector(database, isMaster)
	if err != nil {
		// Stop here rather than handing a nil dialector to gorm.Open: callers
		// should get the same nil result they get for any other connection
		// failure, not a handle that only breaks later on first use.
		log.Print(my_errors.ErrorDialectorDbInitFail, zap.Error(err))
		return nil
	}

	db, err := gorm.Open(dbDialector, &gorm.Config{})
	if err != nil {
		log.Print(my_errors.ErrorConnectDbFail, zap.Error(err))
		return nil
	}

	if conf.Conf.AppDebug {
		return db.Debug()
	}
	return db
}

// getDbDialector returns the GORM Dialector (the driver-specific connection
// opener) for the database selected by database and isMaster.
//
// The driver is chosen from the configured database type, matched
// case-insensitively:
//   - "mysql"                             -> MySQL
//   - "sqlserver", "mssql"                -> SQL Server
//   - "postgres", "postgresql", "postgre" -> PostgreSQL
//
// Any other type produces an error whose message is
// my_errors.ErrorDbDriverNotExists followed by the configured type.
func getDbDialector(database string, isMaster bool) (gorm.Dialector, error) {
	dbType, dsn := getDsn(database, isMaster)

	switch strings.ToLower(dbType) {
	case "mysql":
		return mysql.Open(dsn), nil
	case "sqlserver", "mssql":
		return sqlserver.Open(dsn), nil
	case "postgres", "postgresql", "postgre":
		return postgres.Open(dsn), nil
	}

	return nil, errors.New(my_errors.ErrorDbDriverNotExists + dbType)
}

// getDsn reads the connection settings for database from conf.DbConf. It uses
// the "master" entry when isMaster is true and the "slave" entry otherwise,
// and formats those settings as a driver-specific DSN string.
//
// The first result is the database type exactly as written in the config (it
// is not lower-cased). The type is matched case-insensitively against the
// MySQL, SQL Server and PostgreSQL aliases. For any other type the DSN is "".
func getDsn(database string, isMaster bool) (string, string) {
	role := "slave"
	if isMaster {
		role = "master"
	}

	// Look the entry up once and read every field from it.
	entry := conf.DbConf[role][database]
	dbType := entry.DatabaseType

	switch strings.ToLower(dbType) {
	case "mysql":
		return dbType, fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=%s&parseTime=True&loc=Local",
			entry.Username, entry.Password, entry.Host, entry.Port, entry.Database, entry.Charset)
	case "sqlserver", "mssql":
		return dbType, fmt.Sprintf("server=%s;port=%d;database=%s;user id=%s;password=%s;encrypt=disable",
			entry.Host, entry.Port, entry.Database, entry.Username, entry.Password)
	case "postgresql", "postgre", "postgres":
		return dbType, fmt.Sprintf("host=%s port=%d dbname=%s user=%s password=%s sslmode=disable TimeZone=Asia/Shanghai",
			entry.Host, entry.Port, entry.Database, entry.Username, entry.Password)
	default:
		return dbType, ""
	}
}
